# src/evaluation/metrics_calculator.py
import numpy as np
from sklearn.metrics import (
    accuracy_score,
    f1_score,
    precision_score,
    recall_score,
    classification_report,
    roc_auc_score
)
from sklearn.preprocessing import label_binarize
import logging

# Fallback logger for this module. The metric functions accept an injected
# logger (normally the one created by the experiment script); this one is only
# used when none is passed, e.g. when the module is exercised directly.
# NOTE(review): setting a level on a library module's logger overrides whatever
# the application configured for it; consider leaving the level unset.
module_logger = logging.getLogger(name=__name__)
module_logger.setLevel(level=logging.INFO)

def compute_classification_metrics(
    y_true,
    y_pred,
    y_probs,  # Predicted probabilities from the model (n_samples, n_classes)
    num_classes: int,
    class_names: list,
    logger_instance: logging.Logger = None,  # Inject an external logger
):
    """
    Calculates and logs various classification performance metrics.

    The metrics are computed in three independent groups; an error in one group
    is logged and does not prevent the other groups from being computed:

    1. Accuracy and precision/recall/F1, averaged with ``'binary'`` when
       ``num_classes <= 2`` (sklearn's default positive label) and ``'macro'``
       otherwise.
    2. A per-class classification report (logged as text), from which the
       weighted-average precision/recall/F1 and the per-class F1 are stored.
    3. ROC AUC, only when ``y_probs`` is not None and ``num_classes > 1``.

    Args:
        y_true (list or np.array): True labels.
        y_pred (list or np.array): Predicted labels from the model.
        y_probs (list or np.array or None): Predicted probabilities for ROC AUC
            calculation. If None, ROC AUC is skipped.
        num_classes (int): The total number of classes.
        class_names (list): Names of each class for the classification report.
        logger_instance (logging.Logger, optional): The logger object to use.
            Falls back to the module logger if None.

    Returns:
        dict: Metric name -> value. Keys that may appear: ``accuracy``;
        ``f1_score_{avg}``, ``precision_{avg}``, ``recall_{avg}`` where ``avg``
        is ``binary`` or ``macro``; ``f1_score_weighted``,
        ``precision_weighted``, ``recall_weighted``; ``f1_score_{class_name}``
        for each class found in the report; and either ``roc_auc`` (binary) or
        ``roc_auc_macro_ovr`` / ``roc_auc_weighted_ovr`` (multi-class).
        Metrics whose computation failed are simply absent.
    """
    log = logger_instance if logger_instance is not None else module_logger
    metrics_results = {}

    _add_basic_metrics(metrics_results, y_true, y_pred, num_classes, log)
    _add_report_metrics(metrics_results, y_true, y_pred, num_classes, class_names, log)

    if y_probs is not None and num_classes > 1:
        _add_roc_auc_metrics(metrics_results, y_true, y_probs, num_classes, class_names, log)

    return metrics_results


def _add_basic_metrics(metrics_results, y_true, y_pred, num_classes, log):
    """Add accuracy and averaged F1/precision/recall to ``metrics_results``."""
    log.info("--- Calculating Basic Classification Metrics ---")
    # 'binary' scores only the positive label; 'macro' is the unweighted mean
    # over all classes.
    avg_type_sklearn = 'macro' if num_classes > 2 else 'binary'

    try:
        metrics_results['accuracy'] = accuracy_score(y_true, y_pred)
        metrics_results[f'f1_score_{avg_type_sklearn}'] = f1_score(y_true, y_pred, average=avg_type_sklearn, zero_division=0)
        metrics_results[f'precision_{avg_type_sklearn}'] = precision_score(y_true, y_pred, average=avg_type_sklearn, zero_division=0)
        metrics_results[f'recall_{avg_type_sklearn}'] = recall_score(y_true, y_pred, average=avg_type_sklearn, zero_division=0)

        log.info(f"Accuracy: {metrics_results['accuracy']:.4f}")
        log.info(f"F1-Score ({avg_type_sklearn}): {metrics_results[f'f1_score_{avg_type_sklearn}']:.4f}")
        log.info(f"Precision ({avg_type_sklearn}): {metrics_results[f'precision_{avg_type_sklearn}']:.4f}")
        log.info(f"Recall ({avg_type_sklearn}): {metrics_results[f'recall_{avg_type_sklearn}']:.4f}")
    except Exception as e:
        # Broad catch keeps the other metric groups running; log.exception
        # keeps the traceback so the failure stays diagnosable.
        log.exception(f"Error during basic classification metrics calculation: {e}")


def _report_labels(y_true, y_pred, num_classes, class_names):
    """Return explicit labels for ``classification_report``, or None to let sklearn infer them.

    Without ``labels``, sklearn infers the label set from ``y_true``/``y_pred``
    and raises a ValueError when its size differs from ``len(target_names)``,
    e.g. when some class never occurs in the evaluation data. Passing
    ``range(num_classes)`` avoids that, but only makes sense when every observed
    label equals an index into ``class_names``; otherwise None is returned and
    sklearn's own inference is kept.
    """
    if class_names is None or len(class_names) != num_classes:
        return None
    expected = list(range(num_classes))
    observed = set(np.unique(np.asarray(y_true)).tolist())
    observed |= set(np.unique(np.asarray(y_pred)).tolist())
    return expected if observed <= set(expected) else None


def _add_report_metrics(metrics_results, y_true, y_pred, num_classes, class_names, log):
    """Log the classification report and add weighted and per-class metrics."""
    log.info("--- Class-wise Performance Report ---")
    try:
        labels = _report_labels(y_true, y_pred, num_classes, class_names)

        # Generate the report as a string for logging
        report_str = classification_report(
            y_true,
            y_pred,
            labels=labels,
            target_names=class_names,
            zero_division=0,
        )
        log.info(f"\nClassification Report:\n{report_str}")

        # Generate the report as a dictionary for parsing
        report_dict = classification_report(
            y_true,
            y_pred,
            labels=labels,
            target_names=class_names,
            output_dict=True,
            zero_division=0,
        )

        # Store weighted average metrics
        weighted = report_dict.get('weighted avg')
        if weighted is not None:
            metrics_results['f1_score_weighted'] = weighted['f1-score']
            metrics_results['precision_weighted'] = weighted['precision']
            metrics_results['recall_weighted'] = weighted['recall']

        # Store per-class metrics. The report dict also contains scalar rows
        # (e.g. 'accuracy'), so only dict-valued rows are per-class entries.
        for class_name in class_names or ():
            class_entry = report_dict.get(class_name)
            if isinstance(class_entry, dict):
                metrics_results[f'f1_score_{class_name}'] = class_entry['f1-score']

    except Exception as e:
        log.exception(f"Error generating classification report: {e}")


def _add_roc_auc_metrics(metrics_results, y_true, y_probs, num_classes, class_names, log):
    """Add ROC AUC (binary, or macro/weighted one-vs-rest) to ``metrics_results``."""
    log.info("--- Calculating ROC AUC Score ---")
    try:
        y_probs_array = np.asarray(y_probs)

        if num_classes == 2:  # Binary classification
            if y_probs_array.ndim == 2:
                # (n, 2): use the positive-class (second) column. (n, 1): the
                # single column is taken to be the positive-class score.
                positive_col = 1 if y_probs_array.shape[1] > 1 else 0
                scores_for_roc = y_probs_array[:, positive_col]
            else:
                scores_for_roc = y_probs_array
            roc_auc = roc_auc_score(y_true, scores_for_roc)
            log.info(f"ROC AUC: {roc_auc:.4f}")
            metrics_results['roc_auc'] = roc_auc
        else:  # Multi-class classification (One-vs-Rest approach)
            _add_multiclass_roc_auc(metrics_results, y_true, y_probs_array, num_classes, class_names, log)
    except ValueError as ve:
        log.warning(f"Could not calculate ROC AUC score: {ve}. This can happen if y_true contains only one class.")
    except Exception as e:
        log.exception(f"An unexpected error occurred during ROC AUC calculation: {e}")


def _add_multiclass_roc_auc(metrics_results, y_true, y_probs_array, num_classes, class_names, log):
    """Add macro and weighted one-vs-rest ROC AUC over the classes where it is defined.

    A class's one-vs-rest AUC needs both positive and negative samples of that
    class in ``y_true``. Classes without them are skipped with a warning instead
    of aborting the whole computation. When every class qualifies, the values
    equal sklearn's ``roc_auc_score(..., average='macro' / 'weighted')`` on the
    binarized labels (weights = number of positive samples per class).
    """
    if y_probs_array.ndim != 2 or y_probs_array.shape[1] != num_classes:
        log.warning(
            f"ROC AUC (OvR): y_probs has shape {y_probs_array.shape}, "
            f"expected (n_samples, {num_classes}); skipping."
        )
        return

    labels_binarized = label_binarize(y_true, classes=list(range(num_classes)))
    positives = labels_binarized.sum(axis=0)
    n_samples = labels_binarized.shape[0]
    has_both = (positives > 0) & (positives < n_samples)
    valid = np.flatnonzero(has_both)

    if valid.size == 0:
        log.warning("ROC AUC (OvR): Cannot be computed because y_true contains only one class.")
        return

    skipped = np.flatnonzero(~has_both)
    if skipped.size:
        skipped_names = [
            class_names[c] if class_names is not None and c < len(class_names) else int(c)
            for c in skipped
        ]
        log.warning(
            "ROC AUC (OvR): skipping classes without both positive and negative "
            f"samples in y_true: {skipped_names}"
        )

    per_class_auc = np.array(
        [roc_auc_score(labels_binarized[:, c], y_probs_array[:, c]) for c in valid]
    )
    roc_auc_macro_ovr = float(np.mean(per_class_auc))
    roc_auc_weighted_ovr = float(np.average(per_class_auc, weights=positives[valid]))

    log.info(f"ROC AUC (Macro OvR): {roc_auc_macro_ovr:.4f}")
    log.info(f"ROC AUC (Weighted OvR): {roc_auc_weighted_ovr:.4f}")
    metrics_results['roc_auc_macro_ovr'] = roc_auc_macro_ovr
    metrics_results['roc_auc_weighted_ovr'] = roc_auc_weighted_ovr